#8282: Support non-4d tensor and fp32_dest_acc_en for moreh nllloss backward #8966

Merged: 15 commits merged into main on Jun 3, 2024

Conversation

hschoi4448
Contributor

  • Refactoring moreh nllloss backward

    • Support non-4D tensor input
    • Support fp32_dest_acc_en
  • Add moreh helper functions for fp32_dest_acc_en

    • Refactoring moreh_nll_loss forward kernels


ALWI void ACQ() { acquire_dst(tt::DstMode::Half); }
ALWI void REL() { release_dst(tt::DstMode::Half); }
#include "debug/dprint.h" // required in all kernels using DPRINT
Contributor

I think it's better to remove this.

Contributor Author

removed the dprint header.

Comment on lines 53 to 58
union {
    float f;
    uint32_t u;
} one, zero;
one.f = 1.0f;
zero.f = 0.0f;
Contributor

Contributor Author

Now using the Scalar structure.
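
(Aside: the union in the hunk above is just a bit-cast that reinterprets a float's IEEE-754 bit pattern as a uint32 so the value can be passed around as a raw scalar; the Scalar structure mentioned here presumably wraps the same idea. A minimal Python sketch of that bit-cast, purely for illustration and not tt-metal API:)

import struct

def f32_bits(value: float) -> int:
    # Reinterpret the IEEE-754 bit pattern of a 32-bit float as a uint32,
    # which is what the union { float f; uint32_t u; } above does.
    return struct.unpack("<I", struct.pack("<f", value))[0]

print(hex(f32_bits(1.0)), hex(f32_bits(0.0)))  # 0x3f800000 0x0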

Comment on lines 333 to 347
for _ in range(2):
    tt_input_grad = ttl.operations.primary.moreh_nll_loss_backward(
        tt_target,
        tt_weight,
        tt_divisor,
        tt_output_grad,
        tt_input_grad,
        ignore_index,
        reduction_mean,
    )
Contributor
@TT-BrianLiu TT-BrianLiu May 30, 2024

What is the point of this loop? If it's to test program cache:

  • add an assert on program entries
  • shift the input/output memory by adding some dummy tensor

Contributor Author

What is the point of this loop? If it's to test program cache:

Yes, this is to test the program cache. If it runs only once, the code related to override_runtime_args_callback will not be executed.

add an assert on program entries

I didn't understand that suggestion. There are already asserts like:

            TT_ASSERT(input_tensors.size() == 2);
            TT_ASSERT(optional_input_tensors.size() == 2);
            TT_ASSERT(output_tensors.size() == 1);

If that's not what you meant, could you provide an example?

shift the input/output memory by adding some dummy tensor

I've modified the callback test to receive random input every time.

Contributor

If you want to test program cache:

  • Ensure that program cache is actually hit. You do this by checking that the number of generated caches for an op is 1 (or however many your test generates) when you loop it twice. You can query the number of program caches with device.num_program_cache_entries().
  • Ensure that the callback is actually correct. To properly test the callback, you have to make sure that the runtime arg updates actually matter in your test. Most often, the callback updates the input/output buffer addresses, and trivially looping the test usually leaves your input/output tensors in the same location for both runs (i.e. the same buffer addresses). If that is the case, the test will always pass even when there is nothing in the callback and the data is different the second time. To convince yourself, print out the values of the args being updated in the callback and check whether they match the ones from when the program was first compiled and launched; if they do, the test isn't really testing anything. One hack is to shift the device memory where your inputs/output are expected to be by allocating a small tensor in between the loops (see the sketch below). As an exercise, comment out the callback again and you will see the test actually fail the second time, since your inputs/output are now in a different location.
[Screenshot in the original comment: a code example that allocates a dummy tensor between the loop iterations.]
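
For example, a rough sketch of both checks, written against the helpers already used in this test (get_torch_tensors, get_tt_tensors, and ttl.operations.primary.moreh_nll_loss are taken from the snippets in this thread; the exact dummy allocation is just an assumption):

def test_moreh_nll_loss_callback_sketch(shape, device, ignore_index, reduction_mean, none_weight):
    num_cache_entries = None
    for i in range(2):
        # Fresh host-side tensors each iteration.
        (torch_input, torch_target, torch_weight, torch_divisor, torch_output) = get_torch_tensors(shape)
        if none_weight:
            torch_weight = None

        (tt_input, tt_target, tt_weight, tt_divisor, tt_output) = get_tt_tensors(
            torch_input, torch_target, torch_weight, torch_divisor, torch_output, device
        )

        tt_loss = ttl.operations.primary.moreh_nll_loss(
            tt_input, tt_target, tt_weight, tt_divisor, tt_output, ignore_index, reduction_mean
        )

        if i == 0:
            # Remember how many programs the first run compiled ...
            num_cache_entries = device.num_program_cache_entries()
            # ... and shift device memory so the second run sees different buffer addresses.
            dummy_tensors = get_tt_tensors(*get_torch_tensors(shape), device)
        else:
            # The second run must hit the program cache: no new programs may be compiled.
            assert device.num_program_cache_entries() == num_cache_entries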

Contributor Author
@hschoi4448 hschoi4448 Jun 1, 2024

Trivially looping the test most often results in your input/output tensors being in the same location for two runs (ie. same buffer addresses).

Thank you for the comment. As you mentioned, when I ran the for loop twice without re-creating the tensors, I confirmed that the tensor addresses stayed the same. So I moved the tensor creation inside the for loop and observed that the addresses now change between the two runs. Since the buffer addresses already change this way, the dummy tensor seems unnecessary.

This has been fixed in f3c5ab9.

    for _ in range(2):
        # In each loop, a new tt tensor and value are created.
        (torch_input, torch_target, torch_weight, torch_divisor, torch_output) = get_torch_tensors(shape)
        if none_weight:
            torch_weight = None

        (tt_input, tt_target, tt_weight, tt_divisor, tt_output) = get_tt_tensors(
            torch_input, torch_target, torch_weight, torch_divisor, torch_output, device
        )

        tt_loss = ttl.operations.primary.moreh_nll_loss(
            tt_input,
            tt_target,
            tt_weight,
            tt_divisor,
            tt_output,
            ignore_index,
            reduction_mean,
        )

device.num_program_cache_entries().

The above function appears to return the number of generated program cache entries. However, in the case of NLL loss, the op internally calls moreh_sum, so the number of generated caches depends on the implementation of moreh_sum. Therefore, checking this number might not be the correct approach.

Instead of the num_program_cache_entries function, how about adding a boolean variable to ProgramCache along with enable_cache_check() and disable_cache_check() functions, and then incorporating checks like TT_ASSERT(cache_hit)?

struct ProgramCache {
    inline std::optional<std::shared_ptr<void>> find(uint64_t program_hash) {
        auto cache_hit = this->cache_.count(program_hash) > 0;
        if (is_cache_check_enabled_) {
            TT_ASSERT(cache_hit);
        }
        if (cache_hit) {
            return this->cache_.at(program_hash);
        }
        return std::nullopt;
    }

    void enable_cache_check() {
        is_cache_check_enabled_ = true;
    }

    void disable_cache_check() {
        is_cache_check_enabled_ = false;
    }

  private:
    inline static bool is_cache_check_enabled_ = false;
};
def test_callback():
    ...
    for i in range(2):
        if i == 1:
            # After enabling cache_check, a cache miss triggers an assertion.
            device.enable_cache_check()

        run_tt_op()

Contributor

Your method without dummy tensors probably also works because Python doesn't deallocate the old tt tensors when it creates the next set.

I don't think we need to add anything new to ProgramCache. It doesn't matter which implementation of moreh_sum this test uses. You just have to assert against the number of expected caches.
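
Concretely, the kind of assertion being suggested could look something like this (run_tt_op stands in for the pseudocode above; the expected count is simply whatever the first run produces, so it does not depend on how many programs moreh_sum contributes):

run_tt_op()  # first run compiles and caches the programs (moreh_nll_loss plus whatever moreh_sum needs)
expected_entries = device.num_program_cache_entries()

for _ in range(2):
    run_tt_op()
    # Later runs must not add new cache entries, i.e. every program was a cache hit.
    assert device.num_program_cache_entries() == expected_entries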

@hschoi4448 hschoi4448 force-pushed the hyungsuk/moreh_nllloss_backward_refactoring branch from 1fb8d3f to d3b2857 on June 3, 2024 23:41
@hschoi4448 hschoi4448 merged commit 4edfbe0 into main Jun 3, 2024
5 checks passed
@hschoi4448 hschoi4448 deleted the hyungsuk/moreh_nllloss_backward_refactoring branch June 3, 2024 23:42